# Module-level registry: formula name -> object that handles that formula.
FUNC_MAP = {}


def get_FUNC_MAP():
    """Return the module-level formula registry.

    The live dictionary is returned, not a copy, so changes made by the
    caller are visible to every other user of the registry.
    """
    return FUNC_MAP


import importlib
import logging
import yaml
import os
logger = logging.getLogger(__name__)

from src.formula_base import MyMLError
def my_import(name):
    """Resolve a dotted path such as ``package.module.Attr`` to an object.

    Everything before the last dot is imported as a module; the part after
    it is looked up as an attribute of that module.

    Args:
        name: Dotted path containing at least one ``.``.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``name`` contains no dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no such attribute.
    """
    split_at = name.rindex('.')
    module_path, attr_name = name[:split_at], name[split_at + 1:]
    return getattr(importlib.import_module(module_path), attr_name)


def loadOneFunc(f):
    """Load one formula definition file and register its class in FUNC_MAP.

    The YAML document must provide ``formula_name`` (used as the registry
    key) and ``class`` (a dotted path resolved with ``my_import``). The class
    is instantiated with no arguments and stored under ``formula_name``.

    Args:
        f: Path to the ``.yml`` definition file.

    Returns:
        The parsed YAML content of the file.

    Raises:
        KeyError: If the document has no ``formula_name`` key.
        Exception: If the ``class`` field is missing or null, or if
            ``formula_name`` is already present in FUNC_MAP.
    """
    with open(f, encoding='utf-8') as _f:
        # safe_load builds plain data only. Calling yaml.load without an
        # explicit Loader is rejected by PyYAML >= 6 and unsafe before that.
        fun_json = yaml.safe_load(_f)
    formula_name = fun_json['formula_name']
    class_str = fun_json.get('class')
    if class_str is None:
        raise Exception(f'没有定义class字段: {f}')
    func_cls = my_import(class_str)
    logger.debug('load class %s success,yml=%s', class_str, f)
    if formula_name in FUNC_MAP:
        # The f-prefix was missing, so the path was never interpolated.
        raise Exception(f'func key 重复定义:yml={f}')
    FUNC_MAP[formula_name] = func_cls()
    return fun_json


def loadOneDir(_dir):
    """Load every ``.yml`` file found directly inside ``_dir``.

    Each matching file is passed to ``loadOneFunc``; subdirectories are not
    searched.

    Args:
        _dir: Directory to scan.

    Returns:
        A list with the value returned by ``loadOneFunc`` for each file, in
        the order ``os.listdir`` reported them.
    """
    return [
        loadOneFunc(os.path.join(_dir, entry))
        for entry in os.listdir(_dir)
        if entry.endswith('.yml')
    ]

def read():
    """Rebuild FUNC_MAP from the ``src`` directory next to this file.

    The registry is emptied first, so entries from a previous load do not
    survive.

    Returns:
        A new list holding the result of ``loadOneDir`` for that directory.
    """
    FUNC_MAP.clear()
    src_dir = os.path.join(os.path.dirname(__file__), 'src')
    return list(loadOneDir(src_dir))

# Import-time side effect: load every formula definition so FUNC_MAP is
# populated, then print the parsed definitions and the resulting registry.
print(read())
print(FUNC_MAP)

def infer(formula_name, params):
    """Run the registered formula called ``formula_name`` on ``params``.

    Args:
        formula_name: Key to look up in FUNC_MAP.
        params: Passed unchanged to the registered object's ``calc`` method.

    Returns:
        Whatever the registered object's ``calc`` returns.

    Raises:
        MyMLError: If ``formula_name`` is not present in FUNC_MAP.
    """
    if formula_name in FUNC_MAP:
        return FUNC_MAP[formula_name].calc(params)
    raise MyMLError('1', '公式名称不存在')